1
從急切運算到基於區塊的平行運算
AI023Lesson 3
00:00

PyTorch 急切模式 轉換至 Triton 需要將張量視為單一整體物件的觀念,轉變為將其視為可分割、易於管理的 區塊 或稱為瓦片。

1. PyTorch 與 Triton 張量

必須清楚區分 Triton 張量PyTorch 張量。PyTorch 張量是 主機端的 Python 物件 包裝了形狀、資料類型、裝置、步幅及儲存元資料的物件。相較之下,Triton 使用的是特定記憶體區塊內的 原始資料指標 以進行更底層的優化。

2. 態急模式的瓶頸

在標準的急切執行中,每一個運算(例如加法後接 ReLU)都需要獨立啟動一個核心,並進行一次 全域記憶體往返。這正是現代 GPU 計算的主要瓶頸。Triton 透過在單一核心內融合多個運算來克服此問題,該核心會直接在晶片內部記憶體中處理資料區塊(例如 128、256 或 512 個元素)。 融合 在單一核心內處理資料區塊(例如 128、256 或 512 個元素),並直接在晶片記憶體中運作。

3. 區塊導向的範式

與傳統 CUDA 線程的標量級思維不同,Triton 改用 SPMD(單一程式,多重資料) 於區塊層級。您只需撰寫一個核心,Triton 便會在整個網格上啟動多個實例。每個實例利用其 program_id 來計算它所擁有的「資料區塊」是哪一部分。

PyTorch 張量[元資料封裝]區塊 0(pid 0)區塊 1(pid 1)區塊 2(pid 2)

4. 環境設定

開始前,請 在乾淨的環境中安裝 Triton (使用 Conda 或 venv)以確保不會與現有的 CUDA 工具包產生相依性衝突: pip install triton

main.py
TERMINALbash — 80x24
> Ready. Click "Run" to execute.
>